/*____________________________________________________________________________
	Copyright (C) 1997 Network Associates Inc. and affiliated companies.
	All rights reserved.
	
	
	
	$Id: CSocket.cp,v 1.5 2000/08/06 04:13:00 jason Exp $
____________________________________________________________________________*/

#include <Gestalt.h>

#include "pgpMem.h"
#include "pgpEncode.h"

#include "CInternetUtilities.h"
#include "CUDPSocket.h"
#include "CTCPSocket.h"
#include "CSocket.h"

// Set by Initialize() once Open Transport and sInternetUtilities are ready;
// cleared again by CleanupSockets().
Boolean							CSocket::sInitialized = false;
// Shared helper created in Initialize() and deleted in CleanupSockets().
CInternetUtilities *			CSocket::sInternetUtilities = nil;
// Every live CSocket (inserted by the constructor, removed by the destructor);
// VerifyPGPSocketRef() consults it to validate caller-supplied socket refs.
set<CSocket *>					CSocket::sSocketsSet;
// Per-thread state (idle handler, last error, hostent buffers), keyed by the
// ThreadID returned from GetCurrentThread().
map<ThreadID, SSocketsThreadContext>	CSocket::sThreadContextMap;

// NOTE(review): not referenced in this part of the file; presumably defined
// and used elsewhere in the project.
extern ThreadID					gMainThreadID;

	void
CSocket::Initialize()
{
	if (! sInitialized) {
		OSStatus	err = noErr;

		err = ::InitOpenTransport();
		if (err != kOTNoError) {
			ThrowPGPError_(err);
		}
		sInternetUtilities = new CInternetUtilities;
		if (sInternetUtilities == nil) {
			ThrowPGPError_(kPGPError_OutOfMemory);
		}
		sInitialized = true;
	}
}

	

	void
CSocket::CleanupSockets()
{
	sSocketsSet.clear();

	sThreadContextMap.clear();
	delete sInternetUtilities;
	sInternetUtilities = 0;
	::CloseOpenTransport();
	sInitialized = false;
}


	// VerifyPGPSocketRef()
	//
	// Returns true when inSocketRef is a currently live CSocket, i.e. one that
	// has been constructed and not yet destroyed (see sSocketsSet).
	Boolean
CSocket::VerifyPGPSocketRef(
	CSocket *	inSocketRef)
{
	// A std::set holds each key at most once, so count() is 0 or 1.
	return sSocketsSet.count(inSocketRef) > 0;
}



	// CreateThreadStorage()
	//
	// Ensures the calling thread has an entry in sThreadContextMap and resets
	// that entry to a fresh state (not busy, no error, empty hostent buffers,
	// no idle handler).
	//
	// If the thread already had an entry, a heap copy of its previous contents
	// is returned through outContext so the caller can later hand it back to
	// DisposeThreadStorage(); otherwise *outContext is set to nil.
	//
	// Throws (via ThrowPGPError_) if GetCurrentThread fails or on allocation
	// failure; any saved copy is freed before the exception propagates.
	void
CSocket::CreateThreadStorage(
	SSocketsThreadContext **	outContext)
{
	SSocketsThreadContext *		savedContext = nil;

	try {
		ThreadID	currentThread;
		OSStatus	osErr = ::GetCurrentThread(&currentThread);

		if (osErr != noErr) {
			ThrowPGPError_(osErr);
		}

		map<ThreadID, SSocketsThreadContext>::iterator	entry =
				sThreadContextMap.find(currentThread);

		if (entry == sThreadContextMap.end()) {
			// First use by this thread: add a zero-filled context.
			SSocketsThreadContext	blankContext;

			pgpClearMemory( &blankContext, sizeof(SSocketsThreadContext) );
			pgpLeaksSuspend();
			sThreadContextMap[currentThread] = blankContext;
			pgpLeaksResume();

			entry = sThreadContextMap.find(currentThread);
			if (entry == sThreadContextMap.end()) {
				ThrowPGPError_(kPGPError_OutOfMemory);
			}
		} else {
			// Re-entry: preserve the existing context for the caller.
			savedContext = new SSocketsThreadContext;
			if (savedContext == 0) {
				ThrowPGPError_(kPGPError_OutOfMemory);
			}
			*savedContext = (*entry).second;
		}

		SSocketsThreadContext &	ctx = (*entry).second;

		ctx.isBusy = false;
		ctx.lastError = kPGPError_NoErr;
		ctx.hNameBuffer[0] = 0;
		ctx.hAliasesBuffer = 0;
		// The embedded hostent points into this entry's own buffers.
		ctx.hostEntry.h_name = ctx.hNameBuffer;
		ctx.hostEntry.unused = ctx.hAliasesBuffer;
		ctx.hostEntry.h_addrtype = kPGPAddressFamilyInternet;
		ctx.hostEntry.h_length = sizeof(PGPUInt32);
		ctx.hostEntry.h_addr_list = (char **) ctx.hAddressesListBuffer;
		ctx.idleEventHandler = 0;
		ctx.idleEventHandlerData = 0;
	}

	catch (...) {
		delete savedContext;
		throw;
	}

	*outContext = savedContext;
}



	// DisposeThreadStorage()
	//
	// Counterpart of CreateThreadStorage(). If inContext is non-nil (the copy
	// CreateThreadStorage() saved), its contents are written back into the
	// calling thread's map entry. If inContext is nil, the calling thread's
	// entry is removed from sThreadContextMap (if present).
	//
	// Ownership of inContext is always taken: it is deleted on every path,
	// including when GetCurrentThread fails or the map assignment throws.
	// Previously those error paths leaked it.
	void
CSocket::DisposeThreadStorage(
	SSocketsThreadContext *	inContext)
{
	try {
		ThreadID	theThreadID;
		OSStatus	err = ::GetCurrentThread(&theThreadID);

		if (err != noErr) {
			ThrowPGPError_(err);
		}

		if (inContext != 0) {
			sThreadContextMap[theThreadID] = *inContext;
		} else {
			map<ThreadID, SSocketsThreadContext>::iterator	theIterator =
					sThreadContextMap.find(theThreadID);

			if (theIterator != sThreadContextMap.end()) {
				// Erase via the iterator we already have; no second lookup.
				sThreadContextMap.erase(theIterator);
			}
		}
	}

	catch (...) {
		delete inContext;
		throw;
	}

	delete inContext;
}



	// GetThreadContext()
	//
	// Returns a pointer to the calling thread's entry in sThreadContextMap,
	// or nil if the current thread cannot be identified or has no entry
	// (i.e. CreateThreadStorage() was not called for it).
	//
	// Never throws: errors are deliberately swallowed and reported as nil;
	// callers in this file test the result against nil.
	SSocketsThreadContext *
CSocket::GetThreadContext()
{
	SSocketsThreadContext *		context = nil;

	try {
		ThreadID	currentThread;
		OSStatus	osErr = ::GetCurrentThread(&currentThread);

		if (osErr != noErr) {
			ThrowPGPError_(osErr);
		}

		map<ThreadID, SSocketsThreadContext>::iterator	entry =
				sThreadContextMap.find(currentThread);

		if (entry == sThreadContextMap.end()) {
			ThrowPGPError_(kPGPError_SocketsNoStaticStorage);
		}

		context = &(*entry).second;
	}

	catch (...) {
		// Intentionally ignored; a nil result signals "no context".
	}

	return context;
}



	// CreateSocket()
	//
	// Factory for concrete sockets: kPGPSocketTypeStream yields a CTCPSocket,
	// kPGPSocketTypeDatagram a CUDPSocket. Any other type throws
	// kPGPError_SocketsProtocolNotSupported; allocation failure throws
	// kPGPError_OutOfMemory.
	CSocket *
CSocket::CreateSocket(
	SInt32	inType)
{
	CSocket *	newSocket = (CSocket *) kInvalidPGPSocketRef;

	if (inType == kPGPSocketTypeStream) {
		newSocket = new CTCPSocket;
	} else if (inType == kPGPSocketTypeDatagram) {
		newSocket = new CUDPSocket;
	} else {
		ThrowPGPError_(kPGPError_SocketsProtocolNotSupported);
	}

	// Only reached after one of the two allocations above.
	if (newSocket == nil) {
		ThrowPGPError_(kPGPError_OutOfMemory);
	}

	return newSocket;
}



// CSocket()
//
// Starts with no TLS session and not inside a callback, and registers the
// new object in sSocketsSet so VerifyPGPSocketRef() recognizes it.
CSocket::CSocket()
	: mTLSSession(kInvalidPGPtlsSessionRef),
	  mInCallback(false)
{
	// The set node persists beyond this call; keep it out of leak tracking.
	pgpLeaksSuspend();
	sSocketsSet.insert(this);
	pgpLeaksResume();
}



// ~CSocket()
//
// Unregisters this object so later VerifyPGPSocketRef() calls reject it.
CSocket::~CSocket()
{
	(void) sSocketsSet.erase(this);
}



	// Select()
	//
	// Polls the sockets listed in ioReadSet / ioWriteSet / ioErrorSet (any may
	// be nil) until at least one is readable / writeable / in error, until a
	// listed ref is no longer a live CSocket, or until inTimeout elapses.
	// inTimeout == nil waits indefinitely. A zero timeout polls once.
	// Between polls the thread's idle event handler is called; a non-NoErr
	// result from it is thrown.
	//
	// On return each non-nil input set is replaced by the subset that is
	// ready. The return value is the total number of ready entries.
	//
	// The timeout is measured in Mac ticks (1/60 s). Earlier code converted
	// tv_usec with (1000 * usec) / 60, which is off by about 10^6 (0.5 s
	// became about 38 hours) and could overflow. It also compared against an
	// absolute TickCount() deadline, which breaks when TickCount wraps. Now
	// the elapsed ticks are compared against a saturating duration instead.
	SInt32
CSocket::Select(
	PGPSocketSet *				ioReadSet,
	PGPSocketSet *				ioWriteSet,
	PGPSocketSet *				ioErrorSet,
	const PGPSocketsTimeValue *	inTimeout)
{
	PGPError			pgpErr;
	UInt32				startTicks;
	UInt32				timeoutTicks;
	SInt32				result = 0;
	UInt16				i;
	PGPSocketSet *		readSet = nil;
	PGPSocketSet *		writeSet = nil;
	PGPSocketSet *		errorSet = nil;
	
	try {
		readSet = new PGPSocketSet;
		writeSet = new PGPSocketSet;
		errorSet = new PGPSocketSet;
		if ((readSet == nil) || (writeSet == nil) || (errorSet == nil)) {
			ThrowPGPError_(kPGPError_OutOfMemory);
		}
		
		PGPSOCKETSET_ZERO(readSet);
		PGPSOCKETSET_ZERO(writeSet);
		PGPSOCKETSET_ZERO(errorSet);

		// Convert the timeout into a duration in ticks (60 per second).
		startTicks = ::TickCount();
		if (inTimeout == 0) {
			timeoutTicks = 0xFFFFFFFF;
		} else {
			// Negative fields are treated as zero (i.e. poll once).
			UInt32	seconds = (inTimeout->tv_sec > 0) ?
								(UInt32) inTimeout->tv_sec : 0;
			UInt32	microseconds = (inTimeout->tv_usec > 0) ?
								(UInt32) inTimeout->tv_usec : 0;
			// Fold whole seconds out of tv_usec so the fraction stays
			// below 1,000,000 and (fraction * 60) cannot overflow.
			UInt32	totalSeconds = seconds + (microseconds / 1000000);
			UInt32	fractionTicks = ((microseconds % 1000000) * 60) / 1000000;

			if (totalSeconds >= (0xFFFFFFFF / 60)) {
				timeoutTicks = 0xFFFFFFFF;	// saturate instead of wrapping
			} else {
				timeoutTicks = (60 * totalSeconds) + fractionTicks;
			}
		}
		
		// Check the sets
		Boolean	socketClosed = false;

		do {
			if (ioReadSet != 0) {
				for (i = 0; i < ioReadSet->fd_count; i++) {
					if (VerifyPGPSocketRef(
					(CSocket *) ioReadSet->fd_array[i])) {
						if (((CSocket *) ioReadSet->fd_array[i])->
						IsReadable()) {
							PGPSOCKETSET_SET(	ioReadSet->fd_array[i],
												readSet);
							result++;
						}
					} else {
						socketClosed = true;
					}
				}
			}
			if (ioWriteSet != 0) {
				for (i = 0; i < ioWriteSet->fd_count; i++) {
					if (VerifyPGPSocketRef(
					(CSocket *) ioWriteSet->fd_array[i])) {
						if (((CSocket *) ioWriteSet->fd_array[i])->
							IsWriteable()) {
								PGPSOCKETSET_SET(ioWriteSet->fd_array[i],
									writeSet);
								result++;
							}
					} else {
						socketClosed = true;
					}
				}
			}
			if (ioErrorSet != 0) {
				for (i = 0; i < ioErrorSet->fd_count; i++) {
					if (VerifyPGPSocketRef(
					(CSocket *) ioErrorSet->fd_array[i])) {
						if (((CSocket *) ioErrorSet->fd_array[i])->IsError()) {
							PGPSOCKETSET_SET(ioErrorSet->fd_array[i], errorSet);
							result++;
						}
					} else {
						socketClosed = true;
					}
				}
			}
			
			// Stop as soon as something is ready or a ref has gone stale.
			if ((result > 0) || socketClosed) {
				break;
			}
			
			// Give time to callback
			pgpErr = CallIdleEventHandler();
			ThrowIfPGPError_(pgpErr);
			// Unsigned subtraction keeps the elapsed count correct across a
			// TickCount() wrap-around.
		} while ((::TickCount() - startTicks) < timeoutTicks);
		
		if (ioReadSet != 0) {
			*ioReadSet = *readSet;
		}
		if (ioWriteSet != 0) {
			*ioWriteSet = *writeSet;
		}
		if (ioErrorSet != 0) {
			*ioErrorSet = *errorSet;
		}
		
		delete readSet;
		delete writeSet;
		delete errorSet;
	}
	
	catch (...) {
		delete readSet;
		delete writeSet;
		delete errorSet;
		throw;
	}
	
	return result;
}
	

	// GetSocketOptions()
	//
	// Reports a socket option through outOptionValue:
	//   kPGPSocketOptionAcceptingConnections -> IsListening()
	//   kPGPSocketOptionType                 -> mSocketType
	// NOTE(review): any other option name leaves *outOptionValue untouched;
	// confirm callers never rely on it being set in that case.
	void
CSocket::GetSocketOptions(
	SInt32		inOptionName,
	SInt32 *	outOptionValue)
{
	if (inOptionName == kPGPSocketOptionAcceptingConnections) {
		*outOptionValue = IsListening();
	} else if (inOptionName == kPGPSocketOptionType) {
		*outOptionValue = mSocketType;
	}
}



	// SetIdleEventHandler()
	//
	// Installs inCallback, along with its inUserData, as the calling thread's
	// idle event handler. Select() invokes it (through CallIdleEventHandler())
	// while waiting. Throws kPGPError_UnknownError if the thread has no
	// context.
	void
CSocket::SetIdleEventHandler(
	PGPEventHandlerProcPtr	inCallback,
	PGPUserValue			inUserData)
{
	SSocketsThreadContext *	context = CSocket::GetThreadContext();

	if (context != nil) {
		context->idleEventHandler = inCallback;
		context->idleEventHandlerData = inUserData;
	} else {
		ThrowPGPError_(kPGPError_UnknownError);
	}
}



	// GetIdleEventHandler()
	//
	// Returns the calling thread's idle event handler and its user data
	// through the two out-parameters. Throws kPGPError_UnknownError if the
	// thread has no context.
	void
CSocket::GetIdleEventHandler(
	PGPEventHandlerProcPtr *	outCallback,
	PGPUserValue *				outUserData)
{
	SSocketsThreadContext *	context = CSocket::GetThreadContext();

	if (context != nil) {
		*outCallback = context->idleEventHandler;
		*outUserData = context->idleEventHandlerData;
	} else {
		ThrowPGPError_(kPGPError_UnknownError);
	}
}

	// CallIdleEventHandler()
	//
	// Sends a kPGPEvent_SocketsIdleEvent to the calling thread's idle handler
	// and returns the handler's result. Returns kPGPError_NoErr when the
	// thread has no context or no handler is installed.
	PGPError
CSocket::CallIdleEventHandler()
{
	SSocketsThreadContext *	context = CSocket::GetThreadContext();

	if ((context == nil) || (context->idleEventHandler == nil)) {
		return kPGPError_NoErr;
	}

	PGPEvent	idleEvent;

	pgpClearMemory(&idleEvent, sizeof(idleEvent));
	idleEvent.type = kPGPEvent_SocketsIdleEvent;

	return context->idleEventHandler(	nil,
										&idleEvent,
										context->idleEventHandlerData);
}
